from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import pandas as pd

# 1. Load the data and encode the binary target.
# "class" is mapped Abnormal -> 0, Normal -> 1. Rows whose label is missing
# or not one of those two values are DROPPED so the problem stays binary —
# the previous fillna(-1) silently introduced a spurious third class (-1).
df = pd.read_csv("illness.csv")
df["class"] = df["class"].map({"Abnormal": 0, "Normal": 1})
df = df.dropna(subset=["class"])
df["class"] = df["class"].astype(int)

# Features are every column except the named target — this does not depend
# on "class" happening to be the last column of the CSV.
X = df.drop(columns=["class"])
y = df["class"]

# 2. Train/test split; stratify keeps the class ratio equal in both splits.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# 3. Fit a logistic-regression classifier.
# max_iter raised from the default (100) so the lbfgs solver can converge
# on unscaled feature columns.
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)

# Learned weight vector: one coefficient per feature column.
weights = model.coef_
print("Omega (weights):")
print(weights)

# 4. Predict on the held-out split.
y_pred = model.predict(X_test)

# 5. Evaluate.
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

# NOTE(review): appears to be a precomputed series of running ratios
# (each value looks like k/n for increasing n) — origin and meaning are
# not recorded anywhere in this file; confirm before relying on it.
arr = [
    0.07142857142857142,
    0.13793103448275862,
    0.13333333333333333,
    0.19354838709677416,
    0.25,
    0.303030303030303,
    0.35294117647058826,
    0.39999999999999997,
    0.4444444444444444,
    0.43243243243243246,
    0.4736842105263157,
    0.5128205128205128,
    0.5,
    0.5365853658536585,
    0.5714285714285714,
    0.6046511627906976,
    0.5909090909090909,
    0.5777777777777777,
    0.6086956521739131,
    0.6382978723404256,
    0.6666666666666666,
    0.6530612244897959,
    0.6399999999999999,
    0.627450980392157,
    0.6153846153846153,
    0.6037735849056604,
    0.5925925925925926,
    0.5818181818181818,
    0.5714285714285714,
    0.5614035087719299,
    0.5517241379310345,
    0.5423728813559322,
    0.5333333333333333,
    0.5245901639344261,
    0.5161290322580645,
    0.5079365079365079,
    0.5,
    0.49230769230769234,
    0.4848484848484849,
    0.4776119402985075,
    0.47058823529411764,
    0.463768115942029,
    0.45714285714285713,
    0.4507042253521127,
    0.4444444444444444,
    0.4383561643835616,
    0.4324324324324324,
    0.4266666666666667,
    0.42105263157894735,
    0.41558441558441556,
    0.4102564102564103,
    0.40506329113924056,
    0.4,
    0.3950617283950617,
    0.39024390243902435,
    0.3855421686746988,
    0.38095238095238093,
    0.3764705882352941,
    0.37209302325581395,
    0.36781609195402304,
    0.36363636363636365,
    0.3595505617977528,
    0.3555555555555555,
    0.3516483516483516,
    0.34782608695652173,
    0.3440860215053763,
    0.3404255319148936,
    0.33684210526315783,
    0.3333333333333333,
    0.32989690721649484,
    0.32653061224489793,
    0.3232323232323232,
    0.32,
    0.31683168316831684,
    0.3137254901960784,
    0.3106796116504854,
    0.3076923076923077,
    0.32380952380952377,
    0.33962264150943394,
    0.35514018691588783,
    0.37037037037037035,
    0.38532110091743116,
    0.4,
    0.4144144144144144,
    0.4285714285714286,
    0.4424778761061947,
    0.456140350877193,
    0.46956521739130436,
    0.46551724137931033,
    0.4615384615384615,
    0.4576271186440678,
    0.45378151260504207,
    0.45000000000000007,
    0.44628099173553726,
    0.4426229508196721,
    0.43902439024390244,
    0.435483870967742,
    0.43199999999999994,
    0.42857142857142855,
    0.4251968503937008,
    0.421875,
    0.4186046511627907,
    0.4153846153846154,
    0.4122137404580153,
    0.40909090909090906,
    0.40601503759398494,
    0.40298507462686567,
    0.4,
    0.39705882352941174,
    0.39416058394160586,
    0.391304347826087,
    0.3884892086330935,
    0.38571428571428573,
    0.38297872340425526,
]

# Summary statistics of the series: range, then arithmetic mean.
lowest = min(arr)
highest = max(arr)
mean_value = sum(arr) / len(arr)
print(lowest, highest)
print(mean_value)
